package cn.doitedu.ml.util

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, DataTypes, StructField, StructType}

object ArraySumUDAF extends UserDefinedAggregateFunction{
  // 函数的输入参数，有几个字段，分别是什么类型
  // 比如：arr_sum(flag,other) 需要两个字段
  override def inputSchema: StructType = {
    new StructType().add("flag",DataTypes.createArrayType(DataTypes.DoubleType))
  }

  // buffer 是在聚合函数运算过程中，用于存储局部聚合结果的缓存
  override def bufferSchema: StructType = new StructType().add("buffer",DataTypes.createArrayType(DataTypes.DoubleType))

    // 最后返回结果的数据类型，在本需求中，还是一个Double数组
    override def dataType: DataType = DataTypes.createArrayType(DataTypes.DoubleType)

    // 我们的聚合运算逻辑，是否总是能返回确定结果！
    override def deterministic: Boolean = true

    // 对buffer进行初始化，在本需求逻辑中，可以先给一个长度为0的空double数组
    override def initialize(buffer: MutableAggregationBuffer): Unit = buffer.update(0,Array.emptyDoubleArray)

    // 此方法，就是局部聚合的逻辑所在地，大的逻辑就是，根据输入的一行数据input,来更新局部缓存buffer中的数据
    override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {

      // 从输入的行中，取出flag数组字段
      val inputArr = input.getSeq[Double](0)

      var bufferArr = buffer.getSeq[Double](0)
      // 如果是第一次对buffer做更新操作，那么buffer中的缓存数组应该长度为0，则给他换成跟输入数组长度一致的数组
      if(bufferArr.size<1) bufferArr = Array.fill(inputArr.size)(0.0)

      // 然后，将输入的数组中各个元素按对应位置累加到buffer的数组中
      bufferArr = inputArr.zip(bufferArr).map(tp=>tp._1 + tp._2)

      // 将局部聚合结果，更新到buffer中
      buffer.update(0,bufferArr)

  }

  // 全局聚合逻辑所在地：它是将各个partition的局部聚合结果，一条一条往buffer上累加
  // buffer2代表的是每一个局部聚合结果；  buffer1代表的是本次聚合的存储所在地
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {

    var accuArr = buffer1.getSeq[Double](0)
    val inputArr = buffer2.getSeq[Double](0)

    // 如果是第一次对buffer1做更新操作，那么buffer1中的缓存数组应该长度为0，则给他换成跟输入数组长度一致的数组
    if(accuArr.size<1) accuArr = Array.fill(inputArr.size)(0.0)

    // 然后，将输入的数组中各个元素按对应位置累加到buffer的数组中
    accuArr = inputArr.zip(accuArr).map(tp=>tp._1 + tp._2)

    // 将聚合结果，更新到buffer1中
    buffer1.update(0,accuArr)
  }

  // 最后向外部返回结果的方法，这个方法中的buffer缓存，就是merge方法中的buffer1缓存
  override def evaluate(buffer: Row): Any = {
      buffer.getSeq[Double](0)
  }
}
